import dataclasses
import torch
import torch.nn.functional as F

import megatron.training.arguments
from megatron.training.activations import squared_relu
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
moe_freq_type = megatron.training.arguments.moe_freq_type

def _add_moe_args(parser):
    group = parser.add_argument_group(title="moe")
    # General arguments
    group.add_argument('--expert-model-parallel-size', type=int, default=1,
                       help='Degree of expert model parallelism.')
    group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
                       help='Degree of expert model parallelism. Default is None, which will be set to the value of --tensor-model-paralle-size.')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in MoE (None means no MoE)')
    group.add_argument('--moe-layer-freq', type=moe_freq_type, default=1,
                       help='Frequency between MoE layers and Dense layers. Accepts either: '
                            '- An integer N: Represents a 1:N ratio, meaning one expert layer for every N-1 dense layers '
                            '- A string containing a Python list expression that defines a custom pattern, e.g.: '
                            '"([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0] '
                            'where 1 indicates an expert layer and 0 indicates a dense layer. '
                            'Examples: "([0]+[1]*23)": 1 dense layer followed by 23 experts layers, '
                            '"([1]*3+[0]*2)*2": Three expert layers followed by two dense layers, repeated twice.')
    group.add_argument('--moe-ffn-hidden-size', type=int, default=None,
                       help='The hidden size of each expert\'s feed-forward network (ffn). '
                       'If not specified, defaults to the ffn_hidden_size.')
    group.add_argument('--moe-shared-expert-intermediate-size', type=int, default=None,
                       help='Shared expert total ffn hidden size. '
                       'It should be equal to "num_shared_experts * ffn_size_of_each_shared_expert" if there are multiple shared experts. '
                       'None means no shared expert.')
    group.add_argument('--moe-shared-expert-overlap', action='store_true',
                       help='Enable overlapping between shared expert computations and dispatcher communications. '
                       'Without this, the shared epxerts execute after the routed experts. '
                       'Only effective when moe-shared-expert-intermediate-size is set.')
    group.add_argument('--moe-grouped-gemm', action='store_true',
                       help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
    # Router arguments
    group.add_argument('--moe-router-load-balancing-type', type=str,
                       choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
                       default='aux_loss',
                       help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
    group.add_argument('--moe-router-score-function', type=str,
                       choices=['softmax', 'sigmoid'],
                       default='softmax',
                       help='Score function for MoE TopK routing. Can be "softmax" or "sigmoid".')
    group.add_argument('--moe-router-topk', type=int, default=2,
                       help='Number of experts to route to for each token. The default is 2.')
    group.add_argument('--moe-router-pre-softmax', action='store_true',
                       help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
    group.add_argument('--moe-router-num-groups', type=int, default=None,
                       help='Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups are selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen.'
                       'Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)')
    group.add_argument('--moe-router-group-topk', type=int, default=None,
                       help='Number of selected groups for group-limited routing.')
    group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
                       help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
    group.add_argument('--moe-router-enable-expert-bias', action='store_true',
                       help='TopK routing with dynamic expert bias in the aux-loss-free load balancing strategy. '
                       'The routing decision is based on the sum of the routing scores and the expert bias. '
                       'See https://arxiv.org/abs/2408.15664 for details.')
    group.add_argument('--moe-router-bias-update-rate', type=float, default=1e-3,
                       help='Expert bias update rate in the aux-loss-free load balancing strategy. '
                       'The expert bias is updated based on the number of assigned tokens to each expert in a global batch, '
                       'where the bias is increased for the experts with less assigned tokens and decreased for the experts with more assigned tokens. '
                       'The default value 1e-3 is same as that used in DeepSeekV3.')
    group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
                       help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
    group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
                       help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
    group.add_argument('--moe-z-loss-coeff', type=float, default=None,
                       help='Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended.')
    group.add_argument('--moe-input-jitter-eps', type=float, default=None,
                       help='Add noise to the input tensor by applying jitter with a specified epsilon value.')
    group.add_argument('--moe-token-dispatcher-type', type=str,
                       choices=['allgather', 'alltoall', 'flex', 'alltoall_seq'],
                       default='allgather',
                       help="The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather', 'alltoall' and 'alltoall_seq'. We recommend using 'alltoall' when applying expert parallelism. For more information, please refer to the documentation in core/moe/README.")
    group.add_argument('--moe-enable-deepep', action='store_true',
                       help='[Experimental] Enable DeepSeek/DeepEP for efficient token dispatching and combine in MoE models. Only works with flex token dispatcher by setting --moe-token-dispatcher-type=flex.')
    group.add_argument('--moe-per-layer-logging', action='store_true',
                       help='Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.')
    # Token dropping arguments
    group.add_argument('--moe-expert-capacity-factor', type=float, default=None,
                       help='The capacity factor for each expert, None means no token will be dropped.')
    group.add_argument('--moe-pad-expert-input-to-capacity', action='store_true',
                       help='Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set.')
    group.add_argument('--moe-token-drop-policy', type=str, default='probs', choices=['probs', 'position'],
                       help='The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped.')
    group.add_argument('--moe-layer-recompute', action='store_true',
                       help='Enable checkpointing for moe_layer, should be used when memory is not sufficient.')
    group.add_argument('--moe-extended-tp', action='store_true',
                       help='Deprecated. Use --expert-tensor-parallel-size instead.')
    group.add_argument('--moe-use-upcycling', action='store_true',
                       help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. '
                       'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.')
    group.add_argument('--moe-permute-fusion', action='store_true',
                       help='Fuse token rearrangement ops during token dispatching.')
    
    # HACK(huang.huang): control dp_reduce position: tp-only-amax-red 
    group.add_argument('--tp-only-amax-red', action='store_true',
                        help="Whether to reduce the FP8 AMAX only in the TP or TP-CP domain") 
    ## HACK(huang.huang)

    # HACK(yehua.zhang): add dsv2 & dsv3 loss, q-rms-recompute
    # dsv2
    group.add_argument('--moe-device-level-aux-loss-coeff', type=float, default=None,
                       help='Scaling coefficient for the device-level aux loss')
    group.add_argument('--moe-comm-aux-loss-coeff', type=float, default=None,
                       help='Scaling coefficient for the communication aux loss')
    group.add_argument('--moe-device-level-capacity', action='store_true',
                       help='Whether to consider the expert capacity of a group together')
    
    # dsv3
    group.add_argument('--moe-complementary-seq-aux-loss', action='store_true',
                       help='use complementary sequence-wise aux loss in MoE, should only used with seq_aux_loss')
    group.add_argument('--moe-router-norm-topk-prob', action='store_true',
                       help='Enable normalization for sigmoid score in MoE, should only used with moe-router-use-sigmoid')

    # q-rms-recompute
    group = parser.add_argument_group(title="mla")
    group.add_argument('--q-rms-recompute', action='store_true',
                       help="use q uproj rmsnorm recompute")
    ## HACK(yehua.zhang)

    # HACK(huang.huang): add attn-recompute, recompute-variance, groupMLP_recompute
    group.add_argument('--attn-recompute', action='store_true',
                       help="use attn recompute")
    group.add_argument('--mla-rms-recompute', action='store_true',
                       help="use rms recompute before mla")
    group.add_argument('--mlp-rms-recompute', action='store_true',
                       help="use rms recompute before mlp")
    group.add_argument('--recompute-variance', action='store_true',
                       help="use recompute variance")
    group.add_argument('--mlp-recompute', action='store_true',
                       help="use groupMLP_recompute to recompute groupgemm and shared_exp in moelayer, mlp in dense") 
    ## HACK(huang.huang)
    return parser


def core_transformer_config_from_args(args, config_class=None):

    # Config class.
    config_class = config_class or TransformerConfig

    if args.multi_latent_attention:
        config_class = MLATransformerConfig

    # Translate args to core transformer configuration
    kw_args = {}
    for f in dataclasses.fields(config_class):
        if hasattr(args, f.name):
            kw_args[f.name] = getattr(args, f.name)
    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
    kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
    kw_args['layernorm_epsilon'] = args.norm_epsilon
    kw_args['deallocate_pipeline_outputs'] = True
    kw_args['pipeline_dtype'] = args.params_dtype
    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
    kw_args['num_moe_experts'] = args.num_experts
    kw_args['rotary_interleaved'] = args.rotary_interleaved
    kw_args['num_layers_in_first_pipeline_stage']= args.decoder_first_pipeline_num_layers
    kw_args['num_layers_in_last_pipeline_stage']= args.decoder_last_pipeline_num_layers
    if args.swiglu:
        kw_args['activation_func'] = F.silu
        kw_args['gated_linear_unit'] = True
        kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion
    else:
        kw_args['bias_activation_fusion'] = args.bias_gelu_fusion
    if args.squared_relu:
        assert not args.swiglu
        kw_args['activation_func'] = squared_relu
    if args.init_method_xavier_uniform:
        kw_args['init_method'] = torch.nn.init.xavier_uniform_
        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
    if args.group_query_attention:
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None
    kw_args['config_logger_dir'] = args.config_logger_dir

    if len(args.cp_comm_type) == 1:
        kw_args['cp_comm_type'] = args.cp_comm_type[0]
    
    # Return config.
    
    # HACK(yehua.zhang): add dsv2 & dsv3 loss, mtp, q-rms-recompute from args to transformer config
    config_instance = config_class(**kw_args)

    config_instance.moe_device_level_aux_loss_coeff = args.moe_device_level_aux_loss_coeff
    config_instance.moe_comm_aux_loss_coeff = args.moe_comm_aux_loss_coeff
    config_instance.moe_device_level_capacity = args.moe_device_level_capacity

    config_instance.moe_complementary_seq_aux_loss = args.moe_complementary_seq_aux_loss
    config_instance.moe_router_norm_topk_prob = args.moe_router_norm_topk_prob
    config_instance.moe_device_level_capacity = args.moe_device_level_capacity

    config_instance.q_rms_recompute = args.q_rms_recompute
    ## HACK(yehua.zhang)

    # HACK(huang.huang): add attn-recompute, recompute-variance, mlp_recompute
    config_instance.attn_recompute = args.attn_recompute
    config_instance.mla_rms_recompute = args.mla_rms_recompute
    config_instance.mlp_rms_recompute = args.mlp_rms_recompute
    config_instance.recompute_variance = args.recompute_variance
    config_instance.mlp_recompute = args.mlp_recompute
    ## HACK(huang.huang)

    # HACK(huang.huang): args check for pp=1 and first/last stage num layer=None
    if config_instance.pipeline_model_parallel_size == 1:
        assert config_instance.num_layers_in_first_pipeline_stage is None and config_instance.num_layers_in_last_pipeline_stage is None, \
            f"For pipeline_model_parallel_size=1, first/last must be None, but get {config_instance.num_layers_in_first_pipeline_stage}/{config_instance.num_layers_in_last_pipeline_stage}"
    ## HACK(huang.huang)

    # HACK(huang.huang): control dp_reduce position: tp-only-amax-red 
    config_instance.tp_only_amax_red = args.tp_only_amax_red
    ##HACK(huang.huang)
    
    print('config_instance is ', config_instance)
    return config_instance


megatron.training.arguments._add_moe_args = _add_moe_args
megatron.training.arguments.core_transformer_config_from_args = core_transformer_config_from_args
